import videocapture as video
import numpy as np
import cv2

import time 

from acllite_resource import AclLiteResource
from acllite_model import AclLiteModel
from acllite_imageproc import AclLiteImageProc
from acllite_image import AclLiteImage
from label import labels
from acllite_logger import log_error, log_info


class sampleYOLOV7(object):
    """YOLO inference pipeline on Ascend: resource init, DVPP preprocess,
    model execution, and box-drawing postprocess.

    The model is an offline (.om) network whose NMS is baked in: output[1]
    holds the box count and output[0] a flat box array (see postprocess).
    """

    def __init__(self, model_path, model_width, model_height):
        # Path to the offline (.om) model and the input size it expects.
        self.model_path = model_path
        self.model_width = model_width
        self.model_height = model_height
        # Thresholds kept for reference; NMS runs inside the model here.
        self.conf_thres = 0.25
        self.nms_thres = 0.45
        # Original frame (ndarray) kept for visualization in video mode;
        # stays None in image mode, where postprocess re-reads from disk.
        self.src_image = None

    def init_resource(self):
        """Initialize the ACL runtime, DVPP image processor, and load the model."""
        self._resource = AclLiteResource()
        self._resource.init()

        self._dvpp = AclLiteImageProc(self._resource)
        self._model = AclLiteModel(self.model_path)

    def preprocess_image(self, frame):
        """Decode a JPEG image and resize it to the model input size via DVPP."""
        yuv_image = self._dvpp.jpegd(frame)
        self.resized_image = self._dvpp.resize(yuv_image, self.model_width, self.model_height)

    def preprocess_vis(self, frame):
        """Resize a decoded video frame and keep a color copy for visualization."""
        # Frame bytes are NV21 YUV: height*3//2 rows of `width` bytes each.
        src_image = frame.byte_data_to_np_array().astype(np.uint8)
        self.src_image = cv2.cvtColor(src_image.reshape((frame.height * 3 // 2, frame.width)),
                                      cv2.COLOR_YUV2RGB_NV21)

        self.resized_image = self._dvpp.resize(frame, self.model_width, self.model_height)

    def infer(self):
        """Run the model on the preprocessed image.

        The second model input describes the image size. Previously this was
        hard-coded to 640x640; use the configured model input size so the
        class also works for models with other input resolutions.
        """
        image_info = np.array([self.model_width, self.model_height,
                               self.model_width, self.model_height],
                              dtype=np.float32)
        self.result = self._model.execute([self.resized_image, image_info])

    def postprocess(self, path):
        """Draw detected boxes and labels on the source image and display it.

        Args:
            path: source image path; only read when no video frame was kept.

        Model output layout: result[1][0, 0] is the box count; result[0] is a
        flat array laid out as [x1 * N, y1 * N, x2 * N, y2 * N, score * N, id * N].
        """
        box_num = int(self.result[1][0, 0])  # hoisted: used in every index below
        box_info = self.result[0].flatten()

        # Video mode kept the decoded frame; image mode re-reads from disk.
        if self.src_image is None:
            src_image = cv2.imread(path)
        else:
            src_image = self.src_image
        height, width, _ = src_image.shape
        # Boxes are in model-input coordinates; scale back to source size.
        scale_x = width / self.model_width
        scale_y = height / self.model_height

        colors = [0, 0, 255]
        text = ""
        # draw the boxes in original image
        for n in range(box_num):
            ids = int(box_info[5 * box_num + n])
            score = box_info[4 * box_num + n]
            label = labels[ids] + ":" + str("%.2f" % score)
            top_left_x = box_info[0 * box_num + n] * scale_x
            top_left_y = box_info[1 * box_num + n] * scale_y
            bottom_right_x = box_info[2 * box_num + n] * scale_x
            bottom_right_y = box_info[3 * box_num + n] * scale_y
            cv2.rectangle(src_image, (int(top_left_x), int(top_left_y)),
                          (int(bottom_right_x), int(bottom_right_y)), colors)
            # Keep the label inside the image when the box touches the top edge.
            p3 = (max(int(top_left_x), 15), max(int(top_left_y), 15))
            position = [int(top_left_x), int(top_left_y), int(bottom_right_x), int(bottom_right_y)]
            cv2.putText(src_image, label, p3, cv2.FONT_ITALIC, 0.6, colors, 1)
            text += f'label:{label} {position}  '
        log_info(text)
        cv2.imshow('out', src_image)

    def release_resource(self):
        """Release inference buffers, the model, DVPP processor, and ACL runtime."""
        # resized_image only exists after a preprocess call; guard so releasing
        # after a run with no frames does not raise AttributeError.
        if hasattr(self, 'resized_image'):
            del self.resized_image
        del self._model
        del self._dvpp
        del self._resource

def video_infer(video_path, model):
    """Decode frames from video_path and run detection on each frame.

    A non-zero read status (printed as 'cap read end') or a None frame
    terminates the loop; each decoded frame goes through preprocess,
    inference, and on-screen postprocess.
    """
    cap = video.VideoCapture(video_path)
    while True:
        ret, frame = cap.read()
        # NOTE(review): non-zero ret is treated as end-of-stream here,
        # matching the capture helper's status-code convention.
        if ret:
            print('cap read end! close subprocess cap read')
            break
        if frame is None:
            log_info("read frame finish")
            break
        print('start preprocess')
        model.preprocess_vis(frame)
        model.infer()
        model.postprocess(video_path)
        cv2.waitKey(1)
    del cap

def image_infer(image_path, model):
    """Run the full pipeline once on a single still image.

    Loads the image, preprocesses it, runs inference, draws the results,
    then blocks until a key is pressed so the window stays visible.
    """
    image = AclLiteImage(image_path)
    model.preprocess_image(image)
    model.infer()
    model.postprocess(image_path)
    cv2.waitKey(0)


if __name__ == '__main__':
    # Model configuration: offline model path and its expected input size.
    model_path = '../model/yolov5s_nms.om'
    model_width = 640
    model_height = 640

    # Set up ACL resources and the display window before any inference.
    model = sampleYOLOV7(model_path, model_width, model_height)
    model.init_resource()
    cv2.namedWindow('out', cv2.WINDOW_NORMAL)

    # Switch between single-image and video inference here.
    mode = "video"
    if mode == "image":
        image_infer("../data/dog1_1024_683.jpg", model)
    elif mode == "video":
        video_infer("../data/test.h264", model)
    else:
        print('input mode is incorrect.')

    # Tear down the window and release device resources.
    cv2.destroyAllWindows()
    model.release_resource()
